// u(k) = u(k-1) + kp * (e(k) - e(k-1)) + ki * e(k) + kd * (e(k) - 2e(k-1) + e(k-2))
// u(k) = u(k-1) + a0 * e(k) - a1 * e(k-1) + a2 * e(k-2)
// a0 = kp + ki + kd; a1 = kp + 2kd; a2 = kd

// Incremental (velocity-form) PID controller.
//
// Every time the sample tick fires, the controller latches a new error
// e(k) = dref - din and updates its output with
//     u(k) = u(k-1) + a0 * e(k) - a1 * e(k-1) + a2 * e(k-2)
// where a0 = kp + ki + kd, a1 = kp + 2 * kd and a2 = kd.
// One update takes six clock cycles (states s0 .. s5).
module pid(
    input               clk,    // all registers update on the rising edge
    input               rst_n,  // asynchronous reset, active low
    // controller gains and sample-rate control word
    input   [15 : 0]    kp,     // proportional gain (unsigned)
    input   [15 : 0]    ki,     // integral gain (unsigned)
    input   [15 : 0]    kd,     // derivative gain (unsigned)
    input   [31 : 0]    ti,     // added to the 32-bit phase accumulator every clock
    // data input and output
    input   [15 : 0]    din,    // measured value (unsigned)
    input   [15 : 0]    dref,   // reference / set-point (unsigned)
    output  [15 : 0]    dout    // low 16 bits of the accumulated output u(k)
);

    // Sequencer state encodings. Declaration order is kept as
    // s0, s4, s1, s5, s2, s6, s3, s7. s6 and s7 are not used by the sequencer.
    parameter s0 = 4'd0;
    parameter s4 = 4'd4;
    parameter s1 = 4'd1;
    parameter s5 = 4'd5;
    parameter s2 = 4'd2;
    parameter s6 = 4'd6;
    parameter s3 = 4'd3;
    parameter s7 = 4'd7;

    // ------------------------------------------------------------------
    // Sample tick: a 32-bit phase accumulator advances by ti each clock;
    // a one-clock pulse is produced when its MSB goes from 0 to 1.
    // ------------------------------------------------------------------
    reg [31 : 0]    phase_acc;
    reg             phase_msb_q;    // MSB of phase_acc, delayed by one clock
    reg             sample_tick;

    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            phase_acc   <= 32'd0;
            phase_msb_q <= 1'b0;
            sample_tick <= 1'b0;
        end
        else begin
            phase_acc   <= phase_acc + ti;
            phase_msb_q <= phase_acc[31];
            sample_tick <= phase_acc[31] & ~phase_msb_q;
        end
    end

    // ------------------------------------------------------------------
    // Difference-equation coefficients, re-registered from the gain
    // inputs every clock. Two guard bits hold the sum of three 16-bit gains.
    // ------------------------------------------------------------------
    wire [17 : 0]   kp_ext = {2'b00, kp};
    wire [17 : 0]   ki_ext = {2'b00, ki};
    wire [17 : 0]   kd_ext = {2'b00, kd};

    reg  [17 : 0]   coef_a0;        // kp + ki + kd
    reg  [17 : 0]   coef_a1;        // kp + 2 * kd
    reg  [17 : 0]   coef_a2;        // kd

    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            coef_a0 <= 18'd0;
            coef_a1 <= 18'd0;
            coef_a2 <= 18'd0;
        end
        else begin
            coef_a0 <= kp_ext + ki_ext + kd_ext;
            coef_a1 <= kp_ext + (kd_ext << 1);
            coef_a2 <= kd_ext;
        end
    end

    // ------------------------------------------------------------------
    // Datapath registers
    // ------------------------------------------------------------------
    reg        [3  : 0]     fsm_state;
    reg signed [16 : 0]     err_k;      // e(k)
    reg signed [16 : 0]     err_k1;     // e(k-1)
    reg signed [16 : 0]     err_k2;     // e(k-2)
    reg signed [30 : 0]     product;    // most recent coefficient * error
    reg signed [30 : 0]     partial;    // running sum of the three terms
    reg signed [30 : 0]     control;    // accumulated output u(k)

    // Both operands are zero-extended by one bit so the subtraction is a
    // signed 17-bit operation on non-negative values.
    wire signed [16 : 0]    err_sample = $signed({1'b0, dref}) - $signed({1'b0, din});

    // Unsigned coefficient times signed error, evaluated at 31 bits
    // (the width of the result register).
    function signed [30 : 0] weighted;
        input        [17 : 0] coef;
        input signed [16 : 0] err;
        begin
            weighted = $signed({1'b0, coef}) * err;
        end
    endfunction

    assign dout = control[15 : 0];

    // Sequencer: waits in s0 for a sample tick, then walks s1 .. s5 and
    // returns to s0.
    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            fsm_state <= s0;
        end
        else begin
            case (fsm_state)
                s0      : if (sample_tick) fsm_state <= s1;
                s1      : fsm_state <= s2;
                s2      : fsm_state <= s3;
                s3      : fsm_state <= s4;
                s4      : fsm_state <= s5;
                s5      : fsm_state <= s0;
                default : fsm_state <= s0;
            endcase
        end
    end

    // Datapath: one multiply or one add per state.
    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            err_k   <= 17'd0;
            err_k1  <= 17'd0;
            err_k2  <= 17'd0;
            product <= 31'd0;
            partial <= 31'd0;
            control <= 31'd0;
        end
        else begin
            case (fsm_state)
                s0 : begin  // on a tick: latch the new error, shift the history
                    if (sample_tick) begin
                        err_k  <= err_sample;
                        err_k1 <= err_k;
                        err_k2 <= err_k1;
                    end
                end
                s1 : begin  // a0 * e(k)
                    product <= weighted(coef_a0, err_k);
                end
                s2 : begin  // a1 * e(k-1); keep a0 * e(k)
                    product <= weighted(coef_a1, err_k1);
                    partial <= product;
                end
                s3 : begin  // a2 * e(k-2); a0 * e(k) - a1 * e(k-1)
                    product <= weighted(coef_a2, err_k2);
                    partial <= partial - product;
                end
                s4 : begin  // a0 * e(k) - a1 * e(k-1) + a2 * e(k-2)
                    partial <= partial + product;
                end
                s5 : begin  // u(k) = u(k-1) + increment
                    control <= control + partial;
                end
                default : ;  // unused encodings leave the datapath untouched
            endcase
        end
    end

endmodule